# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class to transform an subgraph into another.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from copy import deepcopy
from functools import partial
from six import iteritems
from six import iterkeys
from six import string_types
from six import StringIO
from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.platform import tf_logging as logging


# Public names exported by this module: the default handlers used by
# `Transformer`, the transformer classes themselves, and the convenience
# functions built on top of them.
__all__ = [
    "replace_t_with_placeholder_handler",
    "keep_t_if_possible_handler",
    "assign_renamed_collections_handler",
    "transform_op_if_inside_handler",
    "copy_op_handler",
    "Transformer",
    "TransformerInfo",
    "copy",
    "copy_with_input_replacements",
    "graph_replace",
]


def replace_t_with_placeholder_handler(info, t):
  """Create a placeholder in the destination graph standing in for `t`.

  Typically installed as the handler for the input tensors of a subgraph, so
  that the transformed subgraph is fed through placeholders.

  Args:
    info: the temporary transform information; its `graph_` (destination
      graph) and `scope_` (destination scope) attributes are used.
    t: the tensor for which a placeholder must be created.
  Returns:
    The tensor produced by `util.make_placeholder_from_tensor`, created while
    the destination graph is the default graph.
  """
  dst_graph = info.graph_
  with dst_graph.as_default():
    return util.make_placeholder_from_tensor(t, scope=info.scope_)


def keep_t_if_possible_handler(info, t):
  """Reuse `t` as-is when possible, otherwise fall back to a placeholder.

  The tensor can only be reused when the source and the destination graphs
  are the very same object. This handler is typically used for the hidden
  inputs of a subgraph.

  Args:
    info: the temporary transform information; its `graph` (source graph) and
      `graph_` (destination graph) attributes are compared by identity.
    t: the tensor to keep or to replace.
  Returns:
    `t` itself if both graphs are the same object, otherwise the result of
    `replace_t_with_placeholder_handler(info, t)`.
  """
  if info.graph is not info.graph_:
    # Crossing graphs: the original tensor cannot be referenced directly.
    return replace_t_with_placeholder_handler(info, t)
  return t


def assign_renamed_collections_handler(info, elem, elem_):
  """Register `elem_` in the counterpart of every collection holding `elem`.

  For each collection of the source graph which contains `elem`, `elem_` is
  added to a collection of the destination graph. The collection keeps its
  name when it is one of the names returned by
  `util.get_predefined_collection_names()`; any other name is first mapped
  through `info.new_name`.

  Args:
    info: the temporary transform information; `collections`, `new_name` and
      `graph_` are used.
    elem: the original element (`tf.Tensor` or `tf.Operation`).
    elem_: the transformed element.
  """
  predefined = util.get_predefined_collection_names()
  for key, members in iteritems(info.collections):
    if elem in members:
      dst_key = key if key in predefined else info.new_name(key)
      info.graph_.add_to_collection(dst_key, elem_)


def transform_op_if_inside_handler(info, op, keep_if_possible=True):
  """Transform an optional op only if it is inside the subgraph.

  This handler is typically used to handle control inputs and original ops:
  ops inside the subgraph are mapped to their transformed counterpart, and
  ops outside of it are either kept or ignored.

  Args:
    info: the temporary transform information. `ops` (the frozenset of the
      subgraph ops), `transformed_ops`, `graph` and `graph_` are used.
    op: the optional op to transform (or ignore).
    keep_if_possible: re-attach to the original op if possible, that is,
      if the source graph and the destination graph are the same.
  Returns:
    The transformed op if `op` belongs to the subgraph; `op` itself if it is
    outside the subgraph, `keep_if_possible` is true and the source and
    destination graphs are the same object; None otherwise.
  """
  # Use the precomputed set of subgraph ops rather than scanning the list of
  # ops of the subgraph view: this handler runs once per control input.
  if op in info.ops:
    return info.transformed_ops[op]
  if keep_if_possible and info.graph is info.graph_:
    return op
  return None


def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None):
  """Create a copy of a `tf.Operation` in the destination graph.

  Args:
    info: the temporary transform information; `new_name` and `graph_` are
      used.
    op: the `tf.Operation` to be copied.
    new_inputs: the inputs to give to the copied op.
    copy_shape: if true, the static shape of every output of `op` is also set
      on the matching output of the copy.
    nodedef_fn: optional function called with the cloned (and already
      renamed) NodeDef; the NodeDef it returns is the one used to build the
      new op. Useful for features which can only be changed on the NodeDef.

  Returns:
    A `(op, op_outputs)` tuple containing the transformed op and its outputs.

  Raises:
    TypeError: if `new_inputs` is a boolean.
  """
  # `new_inputs` was inserted in the signature after `op`; a boolean here
  # most likely comes from a caller still using the previous signature.
  if isinstance(new_inputs, bool):
    raise TypeError("the `new_inputs` argument must be an iterable.")

  # pylint: disable=protected-access
  dst_graph = info.graph_

  # Work on a private NodeDef carrying a name unique in the destination.
  cloned_node_def = deepcopy(op.node_def)
  cloned_node_def.name = dst_graph.unique_name(info.new_name(op.name))
  if nodedef_fn is not None:
    cloned_node_def = nodedef_fn(cloned_node_def)

  # Remaining constructor arguments are copied from the source op.
  out_dtypes = op._output_types[:]
  in_dtypes = op._input_types[:]
  cloned_op_def = deepcopy(op.op_def)

  # The copy is created without control inputs and without original op.
  new_op = tf_ops.Operation(cloned_node_def, dst_graph, new_inputs,
                            out_dtypes, [], in_dtypes, None, cloned_op_def)

  if copy_shape:
    for src_out, dst_out in zip(op.outputs, new_op.outputs):
      dst_out.set_shape(src_out.get_shape())

  # Provisionally point to the source's original op so that the attribute is
  # set on the copy.
  # TODO(fkp): Stop worrying about _original_op and remove this code?
  src_original_op = op._original_op
  if src_original_op:
    new_op._original_op = src_original_op

  # Finally register the new op with the destination graph.
  dst_graph._add_op(new_op)

  return new_op, new_op.outputs


class TransformerInfo(object):
  """Contains information about the result of a transform operation."""

  def __init__(self, info):
    """Constructor.

    Args:
      info: an instance of Transformer._TmpInfo containing various internal
        information about the transform operation. The source/destination
        graphs and scopes, and the op/tensor mappings are read from it (the
        mappings are referenced, not copied).
    """
    self._graph = info.graph
    self._scope = info.scope
    self._graph_ = info.graph_
    self._scope_ = info.scope_
    self._transformed_ops = info.transformed_ops
    self._transformed_ts = info.transformed_ts

  def _get_transformed_map(self, top):
    """Return the correct container depending on the type of `top`.

    Args:
      top: a `tf.Operation` or a `tf.Tensor`.
    Returns:
      The original-to-transformed dictionary of operations or of tensors.
    Raises:
      TypeError: if `top` is neither a `tf.Operation` nor a `tf.Tensor`.
    """
    if isinstance(top, tf_ops.Operation):
      return self._transformed_ops
    elif isinstance(top, tf_ops.Tensor):
      return self._transformed_ts
    else:
      raise TypeError(
          "Expected a tf.Tensor or a tf.Operation, got a {}".format(
              type(top)))

  def _all_pairs(self):
    """Yield every `(original, transformed)` pair, operations first."""
    for mapping in (self._transformed_ops, self._transformed_ts):
      for pair in iteritems(mapping):
        yield pair

  def _transformed_elem(self, original_top, missing_fn=None):
    """Return the transformed op/tensor corresponding to the original one.

    Args:
      original_top: the original tensor/operation, or the name of the
        original tensor/operation.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the transformed tensor/operation (or None if no match is found).
    Raises:
      TypeError: if `original_top` is not a string, a `tf.Operation` or a
        `tf.Tensor`.
    """
    # Names must be handled before selecting a container by type: a string is
    # neither an Operation nor a Tensor, so both mappings are searched.
    if isinstance(original_top, string_types):
      for original, transformed in self._all_pairs():
        if original.name == original_top:
          return transformed
      return None if missing_fn is None else missing_fn(original_top)

    transformed_map = self._get_transformed_map(original_top)
    if original_top in transformed_map:
      return transformed_map[original_top]
    return None if missing_fn is None else missing_fn(original_top)

  def _original_elem(self, transformed_top, missing_fn=None):
    """Return the original op/tensor corresponding to the transformed one.

    Args:
      transformed_top: the transformed tensor/operation, or the name of the
        transformed tensor/operation.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the original tensor/operation (or None if no match is found).
    Raises:
      TypeError: if `transformed_top` is not a string, a `tf.Operation` or a
        `tf.Tensor`.
    """
    if isinstance(transformed_top, string_types):
      # Same reasoning as in `_transformed_elem`: search both mappings.
      for original, transformed in self._all_pairs():
        if transformed.name == transformed_top:
          return original
    else:
      transformed_map = self._get_transformed_map(transformed_top)
      for original, transformed in iteritems(transformed_map):
        # Match the very same object rather than relying on `==`.
        if transformed is transformed_top:
          return original
    return None if missing_fn is None else missing_fn(transformed_top)

  def transformed(self, original, missing_fn=None):
    """Return the transformed op/tensor corresponding to the original one.

    Note that the output of this function mimics the hierarchy
    of its input argument `original`.
    Given an iterable, it returns a list. Given an operation or a tensor,
    it will return an operation or a tensor.

    Args:
      original: the original tensor/operation.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the transformed tensor/operation (or None if no match is found).
    """
    transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn)
    return util.transform_tree(original, transformed_elem)

  def original(self, transformed, missing_fn=None):
    """Return the original op/tensor corresponding to the transformed one.

    Note that the output of this function mimics the hierarchy
    of its input argument `transformed`.
    Given an iterable, it returns a list. Given an operation or a tensor,
    it will return an operation or a tensor.

    Args:
      transformed: the transformed tensor/operation.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the original tensor/operation (or None if no match is found).
    """
    original_elem = partial(self._original_elem, missing_fn=missing_fn)
    return util.transform_tree(transformed, original_elem)

  def __str__(self):
    """Return a multi-line, human readable summary of the transform."""
    res = StringIO()
    print("Transform result info:", file=res)
    if self._graph is self._graph_:
      # Same graph and no destination scope is reported as in-place.
      in_place_str = "" if self._scope_ else " IN-PLACE"
      print("  Within graph[{}]{}".format(
          id(self._graph), in_place_str), file=res)
    else:
      print("  graph[{}] => graph[{}]".format(
          id(self._graph), id(self._graph_)), file=res)
    if self._scope:
      print("  Relative to source scope: {}".format(self._scope), file=res)
    if self._scope_:
      print("  Scope destination: {}".format(self._scope_), file=res)
    print("Operations mapping:", file=res)
    for op, op_ in iteritems(self._transformed_ops):
      print("  {} => {}".format(op.name, op_.name), file=res)
    return res.getvalue()


class _TmpInfo(object):
  """Scratch data shared by all the steps of one transformer call.

  One instance is built at the beginning of `Transformer.__call__`, lives for
  the duration of that call and is handed to every handler.
  """

  def __init__(self, sgv, dst_graph, dst_scope, src_scope):
    # Source side.
    self.sgv = sgv
    self.sgv_inputs_set = frozenset(sgv.inputs)
    self.ops = frozenset(sgv.ops)
    src_graph = sgv.graph
    self.control_outputs = util.ControlOutputs(src_graph)
    self.graph = src_graph
    self.scope = src_scope

    # Destination side.
    self.graph_ = dst_graph
    self.scope_ = dst_scope

    # Original -> transformed mappings, filled while the ops are copied.
    self.transformed_ops = {}
    self.transformed_ts = {}

    # Snapshot of the content of every collection of the source graph.
    self.collections = {
        key: src_graph.get_collection(key)
        for key in src_graph.get_all_collection_keys()
    }

    self.cyclic_ops = []
    self.transform_original_op_handler = transform_op_if_inside_handler

    # Ops are transformed one by one, following their creation order. With
    # cycles (i.e. while loops) an op can be reached before some of its
    # inputs exist in the destination. A temporary tensor is then used and
    # recorded in `tmp_cyclic_ts`; a second pass swaps each of them for the
    # proper transformed tensor (see the function `_finalize_cycles`).
    self.tmp_cyclic_ts = []

  def new_name(self, name):
    """Map a source name to its destination name.

    The source scope prefix is stripped from `name` and the destination scope
    is prepended to what is left.

    Args:
      name: the name to be "transformed".
    Returns:
      The transformed name.
    Raises:
      ValueError: if `name` does not start with the source scope (which can
        only happen when the source scope is not an empty string).
    """
    src_prefix = self.scope
    if name.startswith(src_prefix):
      return self.scope_ + name[len(src_prefix):]
    raise ValueError("{} does not belong to source scope: {}.".format(
        name, src_prefix))


class Transformer(object):
  """Transform a subgraph into another one.

  By default, the constructor creates a transform which copies a subgraph and
  replaces inputs with placeholders. This behavior can be modified by changing
  the handlers.
  """

  def __init__(self):
    """Transformer constructor.

    The following members can be modified:
    transform_op_handler: handle the transformation of a `tf.Operation`.
      This handler defaults to a simple copy.
    transform_control_input_handler: handle the transform of the control
      inputs of the copied ops. This handler defaults to transforming a
      control input only if it is inside the subgraph; a control input outside
      of the subgraph is kept as-is when the source and destination graphs
      are the same, and ignored otherwise.
    assign_collections_handler: handle the assignment of collections.
      This handler defaults to assigning new collections created under the
      given name-scope.
    transform_external_input_handler: handle the transform of the inputs to
      the given subgraph. This handler defaults to creating placeholders
      instead of the ops just before the input tensors of the subgraph.
    transform_external_hidden_input_handler: handle the transform of the
      hidden inputs of the subgraph, that is, the inputs which are not listed
      in sgv.inputs. This handler defaults to a transform which keeps the same
      input if the source and destination graphs are the same, otherwise
      uses placeholders.
    transform_original_op_handler: handle the transform of original_op. This
      handler defaults to transforming original_op only if they are in the
      subgraph, otherwise they are ignored.
    """

    # handlers
    self.transform_op_handler = copy_op_handler
    self.transform_control_input_handler = transform_op_if_inside_handler
    self.assign_collections_handler = assign_renamed_collections_handler
    self.transform_external_input_handler = replace_t_with_placeholder_handler
    self.transform_external_hidden_input_handler = keep_t_if_possible_handler
    self.transform_original_op_handler = transform_op_if_inside_handler

  def __call__(self,
               sgv,
               dst_graph,
               dst_scope,
               src_scope="",
               reuse_dst_scope=False):
    """Execute the transformation.

    The ops are first copied, then the temporary tensors created for cycles
    are rewired, and finally the control inputs are connected.

    Args:
      sgv: the source subgraph-view.
      dst_graph: the destination graph.
      dst_scope: the destination scope.
      src_scope: the source scope, which specifies the path from which the
        relative paths of the transformed nodes are computed. For instance, if
        src_scope is a/ and dst_scope is b/, then the node a/x/y will have a
        relative path of x/y and will be transformed into b/x/y.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).
    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.
    Raises:
      TypeError: if `dst_graph` is not a `tf.Graph`.
      ValueError: if the arguments are invalid.
    """
    sgv = subgraph.make_view(sgv)
    if not isinstance(dst_graph, tf_ops.Graph):
      raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))

    src_scope = util.scope_finalize(src_scope)
    dst_scope = util.scope_finalize(dst_scope)

    # Potentially create new scope if reuse_dst_scope is False.
    # The last character of the finalized scope is dropped before asking the
    # graph for a unique name (presumably the trailing "/" added by
    # `util.scope_finalize` -- confirm against that helper).
    if dst_scope and not reuse_dst_scope:
      dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))

    # Create temporary info used during this transform call
    info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope)

    self._copy_ops(info)
    self._finalize_cycles(info)
    self._connect_control_inputs(info)

    # Compute information about the transformation
    res_info = TransformerInfo(info)
    sgv_ = self._transform_sgv(info, sgv)
    return sgv_, res_info

  def _copy_ops(self, info):
    """Copy ops without connecting their control inputs.

    Every op of the subgraph is transformed with `transform_op_handler`, in
    increasing `_id` order, and the op and its outputs are recorded in
    `info.transformed_ops` and `info.transformed_ts`.

    Args:
      info: Temporary information for this transform call.
    Raises:
      ValueError: if the handler returns the original op itself.
    """
    sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id)  # pylint: disable=protected-access
    for op in sorted_ops:
      # Inputs not transformed yet (cycles) are given temporary tensors by
      # `_transformed_t`.
      new_inputs = [self._transformed_t(info, t, op) for t in op.inputs]
      op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs)
      if op is op_:
        raise ValueError("In-place transformation not allowed.")

      # Process op.
      info.transformed_ops[op] = op_
      self.assign_collections_handler(info, op, op_)

      # Process output tensors.
      for op_output, op_output_ in zip(op.outputs, op_outputs_):
        info.transformed_ts[op_output] = op_output_
        self.assign_collections_handler(info, op_output, op_output_)

  def _finalize_cycles(self, info):
    """Reconnects the cyclic tensors.

    For every entry recorded in `info.tmp_cyclic_ts`, the temporary input of
    the transformed consumer op is replaced by the transformed tensor.

    Args:
      info: Temporary information for this transform call.
    Raises:
      ValueError: if a recorded tensor or consumer op has no transformed
        counterpart.
    """
    for t, tmp_t_, consumer_op in info.tmp_cyclic_ts:
      if t not in info.transformed_ts:
        raise ValueError("The tensor {} should be transformed by now.".format(
            t.name))
      if consumer_op not in info.transformed_ops:
        raise ValueError("The op {} should be transformed by now.".format(
            consumer_op.name))
      t_ = info.transformed_ts[t]
      consumer_op_ = info.transformed_ops[consumer_op]
      # NOTE(review): `index` returns the first input equal to `tmp_t_`. When
      # `_transformed_t` reused an already transformed tensor as the temporary
      # one (Merge / Enter cases), that tensor may also be a regular input of
      # the consumer -- confirm the updated slot is the intended one.
      t_index_ = list(consumer_op_.inputs).index(tmp_t_)
      consumer_op_._update_input(t_index_, t_, update_dtype=False)  # pylint: disable=protected-access

  def _connect_control_inputs(self, info):
    """Connect the previously copied ops.

    Sets the final `_original_op` of each transformed op (when the handler
    finds one) and adds the transformed control inputs.

    Args:
      info: Temporary information for this transform call.
    """
    for op in info.sgv.ops:
      logging.debug("Connecting control inputs of op: %s", op.name)
      op_ = info.transformed_ops[op]

      # Finalize original op.
      # TODO(fkp): Stop worrying about _original_op and remove this code?
      # pylint: disable=protected-access
      if op._original_op:
        original_op = self.transform_original_op_handler(info, op._original_op)
        if original_op is None:
          logging.debug("Could not find original op for: %s", op_.name)
        else:
          op_._original_op = original_op
      # pylint: enable=protected-access

      # Finalize control inputs:
      control_inputs_ = [self.transform_control_input_handler(info, ci)
                         for ci in op.control_inputs]
      # The handler returns None for the control inputs to be dropped.
      control_inputs_ = [ci for ci in control_inputs_ if ci is not None]
      reroute.add_control_inputs(op_, control_inputs_)

  def _transform_sgv(self, info, sgv):
    """Transform a subgraph view.

    For convenience, a transform operation returns a subgraph view of the
    transformed graph.

    Args:
      info: Temporary information for this transform call.
      sgv: the subgraph to be transformed.
    Returns:
      The transformed subgraph.
    """
    ops_ = [op_ for _, op_ in iteritems(info.transformed_ops)]
    sgv_ = subgraph.SubGraphView(ops_)
    sgv_inputs_ = sgv_.inputs
    sgv_outputs_ = sgv_.outputs

    # re-order inputs
    # Original inputs without a transformed counterpart among the inputs of
    # the new view are skipped.
    input_map_ = []
    for input_t in sgv.inputs:
      if input_t not in info.transformed_ts:
        continue
      input_t_ = info.transformed_ts[input_t]
      if input_t_ not in sgv_inputs_:
        continue
      input_t_index_ = sgv_.input_index(input_t_)
      input_map_.append(input_t_index_)

    # re-order outputs
    # Same filtering as for the inputs.
    output_map_ = []
    for output_t in sgv.outputs:
      if output_t not in info.transformed_ts:
        continue
      output_t_ = info.transformed_ts[output_t]
      if output_t_ not in sgv_outputs_:
        continue
      output_t_index_ = sgv_.output_index(output_t_)
      output_map_.append(output_t_index_)

    return sgv_.remap(input_map_, output_map_)

  def _transformed_t(self, info, t, consumer_op):
    """Return the transformed tensor of `t`.

    Args:
      info: Temporary information for this transform call.
      t: the original tensor, an input of `consumer_op`.
      consumer_op: the original op consuming `t`.
    Returns:
      The already transformed tensor if there is one. Otherwise the result of
      the external input handler (subgraph inputs), a temporary tensor
      (tensors produced inside the subgraph but not transformed yet), or the
      result of the hidden input handler (any other tensor).
    """
    if t in info.transformed_ts:
      # If op is in the subgraph, just return its transformed counterpart.
      return info.transformed_ts[t]

    if t in info.sgv_inputs_set:
      # `t` is an input of the subgraph.
      return self.transform_external_input_handler(info, t)
    elif t.op in info.ops:
      # `t` is an internal tensor but is not transformed yet because it
      # belongs to a graph cycle.
      logging.debug("Cyclic tensor: t.name = %s", t.name)
      # Try to find an existing tensor we can use for now,
      # otherwise create one. We'll rewire this later.
      if consumer_op.type == "Merge":
        first_input = consumer_op.inputs[0]
        tmp_t_ = self._transformed_t(info, first_input, consumer_op)
      elif t.op.type == "Enter":
        enter_input = t.op.inputs[0]
        tmp_t_ = self._transformed_t(info, enter_input, consumer_op)
      else:
        with info.graph_.as_default():
          tmp_t_ = util.make_placeholder_from_tensor(t, scope=info.scope_,
                                                     prefix="geph_tmp")
        logging.debug("Created temporary placeholder: %s.", tmp_t_.name)
      # Register as temporary and return.
      info.tmp_cyclic_ts.append((t, tmp_t_, consumer_op))
      return tmp_t_
    else:
      # `t` is a hidden input of the subgraph.
      return self.transform_external_hidden_input_handler(info, t)


def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
         reuse_dst_scope=False):
  """Copy a subgraph using a default `Transformer`.

  Args:
    sgv: the source subgraph-view. This argument is converted to a subgraph
      using the same rules as the function subgraph.make_view.
    dst_graph: the destination graph. When None, the graph of the source
      subgraph is used.
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A tuple `(sgv, info)` where:
      `sgv` is the transformed subgraph view;
      `info` is an instance of TransformerInfo containing
      information about the transform, including mapping between
      original and transformed tensors and operations.
  Raises:
    TypeError: if `dst_graph` is not a `tf.Graph`.
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  view = subgraph.make_view(sgv)
  target_graph = view.graph if dst_graph is None else dst_graph
  if not isinstance(target_graph, tf_ops.Graph):
    raise TypeError(
        "Expected a tf.Graph, got: {}".format(type(target_graph)))

  # A transformer with its default handlers performs a plain copy.
  return Transformer()(
      view, target_graph, dst_scope, src_scope,
      reuse_dst_scope=reuse_dst_scope)


def copy_with_input_replacements(sgv, replacement_ts,
                                 dst_graph=None, dst_scope="", src_scope="",
                                 reuse_dst_scope=False):
  """Copy a subgraph while substituting some of its input tensors.

  A substitution only takes place for a tensor which is an input of the
  given subgraph (see sgv.inputs). Inputs without a replacement are handled
  by `keep_t_if_possible_handler`.

  Args:
    sgv: the source subgraph-view. This argument is converted to a subgraph
      using the same rules as the function subgraph.make_view.
    replacement_ts: dictionary mapping from original tensors to the
      replaced one.
    dst_graph: the destination graph. When None, the graph of the source
      subgraph is used.
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A tuple `(sgv, info)` where:
      `sgv` is the transformed subgraph view;
      `info` is an instance of TransformerInfo containing
      information about the transform, including mapping between
      original and transformed tensors and operations.
  Raises:
    TypeError: if dst_graph is not a tf.Graph.
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  view = subgraph.make_view(sgv)
  target_graph = view.graph if dst_graph is None else dst_graph
  if not isinstance(target_graph, tf_ops.Graph):
    raise TypeError(
        "Expected a tf.Graph, got: {}".format(type(target_graph)))

  def _replace_or_keep(info, t):
    """Use the replacement registered for `t`, if there is one."""
    if t not in replacement_ts:
      return keep_t_if_possible_handler(info, t)
    return replacement_ts[t]

  transformer = Transformer()
  transformer.transform_external_input_handler = _replace_or_keep
  return transformer(
      view, target_graph, dst_scope, src_scope,
      reuse_dst_scope=reuse_dst_scope)


def _add_control_flow_ops(ops, control_ios):
  """Extend `ops` with the ops needed to keep while loops whole.

  Copying only a part of a control flow construct (for instance half of a
  while loop) is likely to produce an invalid graph. For every while context
  attached to one of the given ops, the ops found between the loop enters and
  the loop exits are appended to `ops` when they are not already in it.

  Args:
    ops: list of ops (modified in-place).
    control_ios: object created by a call to `util.ControlOutputs`.
  """
  # Gather the distinct control flow contexts of the given ops.
  contexts = set()
  for candidate in ops:
    context = candidate._control_flow_context  # pylint: disable=protected-access
    if context:
      contexts.add(context)

  # Collect the ops spanned by each while context.
  loop_ops = []
  for context in contexts:
    if not context.IsWhileContext():
      continue
    enter_ops = [enter_t.op for enter_t in context.loop_enters]
    exit_ops = [exit_t.op for exit_t in context.loop_exits]
    loop_ops.extend(select.get_walks_intersection_ops(
        enter_ops, exit_ops, control_ios=control_ios))

  # Append only the ops which were not listed to begin with.
  already_listed = frozenset(ops)
  ops.extend(op for op in set(loop_ops) if op not in already_listed)


def graph_replace(target_ts, replacement_ts, dst_scope="",
                  src_scope="", reuse_dst_scope=False):
  """Recompute the targets from replaced tensors, in a copy of the subgraph.

  The ops lying between the replaced tensors and the targets are copied, the
  copy being fed with the replacement tensors.

  Args:
    target_ts: a single tf.Tensor or an iterable of tf.Tensor.
    replacement_ts: dictionary mapping from original tensors to replaced tensors
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A single tf.Tensor or a list of target tf.Tensor, depending on
    the type of the input argument `target_ts`.
    The returned tensors are recomputed using the tensors from replacement_ts.
    A target without a transformed counterpart is returned unchanged.
  Raises:
    ValueError: if the targets are not connected to replacement_ts.
  """
  flat_targets = util.flatten_tree(target_ts)

  # Forward control dependency edges are needed so that the walks below can
  # also go through control dependencies.
  src_graph = util.get_unique_graph(flat_targets, check_types=tf_ops.Tensor)
  control_ios = util.ControlOutputs(src_graph)

  # The ops to copy are those reached both by a forward walk starting at the
  # tensors being replaced and by a backward walk starting at the targets.
  replaced_ts = list(iterkeys(replacement_ts))
  affected_ops = select.get_walks_intersection_ops(
      replaced_ts, flat_targets, control_ios=control_ios)
  if not affected_ops:
    raise ValueError("Targets and replacements are not connected!")

  # Complete ops to avoid malformed control flow.
  # TODO(fkp): Consider moving this function deeper (in the transformer?).
  _add_control_flow_ops(affected_ops, control_ios)

  # Copy the relevant subgraph within the same graph (dst_graph is None).
  _, info = copy_with_input_replacements(
      affected_ops, replacement_ts, None, dst_scope, src_scope,
      reuse_dst_scope)

  def _keep_original(original_t):
    """Fallback for targets without a transformed counterpart."""
    return original_t

  return info.transformed(target_ts, _keep_original)
